from matplotlib import pyplot as plt
import matplotlib.patches as mpatches

import numpy as np

# Affichage de la table
def AfficheTable(L, g):
    sommets = list(g.keys())
    n = len(sommets)

    # Trouver le i max dans L
    i_max = max(k[0] for k in L.keys()) if L else 0

    # Créer une matrice RGB
    mat = np.zeros((i_max + 1, n, 3))

    for i in range(i_max + 1):
        for idx, v in enumerate(sommets):
            if (i, v) in L:
                mat[i][idx] = [0.2, 0.7, 0.3]  # Vert
            else:
                mat[i][idx] = [0.85, 0.85, 0.85]  # Gris clair

    plt.close('all')
    # Taille de figure plus généreuse
    fig, ax = plt.subplots(figsize=(max(6, n * 1.0), max(4, (i_max + 1) * 0.7)))
    ax.imshow(mat)

    # Afficher les valeurs dans chaque case
    for i in range(i_max + 1):
        for idx, v in enumerate(sommets):
            if (i, v) in L:
                valeur = L[(i, v)]
                if valeur == float('inf'):
                    txt = '∞'
                else:
                    txt = str(int(valeur))
                ax.text(idx, i, txt, ha='center', va='center',
                        color='white', fontsize=10, fontweight='bold')

    # Configurer les axes avec les noms des sommets
    ax.set_xticks(range(n))
    ax.set_xticklabels(sommets, fontsize=10, style='italic')
    ax.set_yticks(range(i_max + 1))
    ax.set_yticklabels(range(i_max + 1), fontsize=10)

    # Quadrillage
    ax.set_xticks(np.arange(-0.5, n, 1), minor=True)
    ax.set_yticks(np.arange(-0.5, i_max + 1, 1), minor=True)
    ax.grid(which='minor', color='black', linewidth=0.5)

    plt.xlabel('Sommets', fontsize=11)
    plt.ylabel('Nombre max. d\'arêtes (i)', fontsize=11)
    plt.title('Table de programmation dynamique', fontsize=12)

    # Ajuster les marges manuellement au lieu de tight_layout
    plt.subplots_adjust(left=0.15, right=0.95, top=0.9, bottom=0.15)
    plt.show()

# Graphe représenté par un dictionnaire d'adjacence
# graphe[u] = [(v1, poids1), (v2, poids2), ...]
graphe = {
    'S': [('U', 2), ('V', 4)],
    'U': [('V', -1), ('W', 2)],
    'V': [('W', 3), ('T', 4)],
    'W': [('T', 2)],
    'T': []
}
source = 'S'
L = {}  # Dictionnaire de mémoïsation


###########################################
# Approche bottom-up
###########################################

def initialiser_cas_de_base(G,S,L):
    for u in G:
        L[(0,u)] = np.inf
    L[(0,S)] = 0
    return L

def obtenir_aretes_entrantes(G,v):
    liste = []

    # Parcourt les sommets dans les valeurs du dictionnaire du graphe
    # pour trouver le sommet v
    for cle, val in G.items():
        # val est de la forme [('X',poids), ('Y',poids), ...]
        for sommet,poids in val:
            if sommet == v:
                liste.append((cle,poids))

    return liste

def remplir_table(G, L):
    # Nombre de sommets
    n = len(G)

    for i in range(1,n+1):
        for u in G.keys():
            # Cas n°1
            L[(i,u)] = L[(i-1,u)]

            # Sous-cas du cas n°2
            for sommet_entrant in obtenir_aretes_entrantes(G,u):
                # sommet_entrant est de la forme ('X',poids)
                L[(i,u)] = min(L[(i,u)],L[(i-1,sommet_entrant[0])] + sommet_entrant[1])

    return L

# Question 6 (théorique)
# - Nombre de sous-problèmes : (n+1) * n = O(n²)
# - Complexité temporelle : O(n * m) car pour chaque sous-problème on examine les arêtes entrantes
# - Complexité spatiale : O(n²) pour stocker tous les sous-problèmes

def bellman_ford_bottomup(G, S):
    # Nombre de sommets
    n = len(G)

    # Initialisation de la table
    L = {}
    initialiser_cas_de_base(G,S,L)

    # Remplissage de la table jusqu'à n
    L = remplir_table(G,L)

    # Détection cycle négatif
    cycle_negatif = False

    for sommet in G:
        if L[(n,sommet)] != L[(n-1,sommet)]:
            cycle_negatif = True

    return L,cycle_negatif

def extraire_distances(L, G):
    n = len(G)
    distances = {cle:L[(n,cle)] for cle in G.keys()}
    return distances

L = initialiser_cas_de_base(graphe,source,L)
print(L)
print(obtenir_aretes_entrantes(graphe, 'V'))
print(obtenir_aretes_entrantes(graphe, 'T'))
print(obtenir_aretes_entrantes(graphe, 'S'))
L = remplir_table(graphe,L)

# Pas de cycle négatif car L[5,v] = L[4,v]
# pour tous les sommets v
AfficheTable(L,graphe)

print(bellman_ford_bottomup(graphe, source))


# Graphe avec cycle négatif
graphe_neg = {
    'S': [('U', 2), ('V', 4)],
    'U': [('V', -1), ('W', 2)],
    'V': [('W', -3), ('T', -4)],
    'W': [('T', -2)],
    'T': [('V',-1)]
}
source = 'S'
L, cycle_negatif = bellman_ford_bottomup(graphe_neg, source)
AfficheTable(L,graphe_neg)
print(cycle_negatif)


source = 'S'
L, cycle_negatif = bellman_ford_bottomup(graphe, source)

print(extraire_distances(L,graphe))



##############################
# Approche top-down
##############################

def rec_bellman_ford1(G,S,dest):
    L = {}
    n = len(G)

    def f_rec(i,v):
        # Utilise la mémoisation
        if (i,v) in L:
            return L[(i,v)]

        # Cas de base
        if i == 0:
            if v == S:
                L[(i,v)] = 0
            else:
                L[(i,v)] = np.inf
            return L[(i,v)]

        # Cas n°1 : hériter de la valeur précédente
        val_opt = f_rec(i-1,v)

        # Cas n°2 : Tester tous les prédécesseurs
        for (u,poids) in obtenir_aretes_entrantes(G,v):
            candidat = f_rec(i-1,u) + poids
            val_opt = min(val_opt,candidat)

        # Mémoisation
        L[(i,v)] = val_opt

        return L[(i,v)]

    distance = f_rec(n-1,dest)
    return distance, L

# Question 3 (théorique)
# Réponse attendue :
# - Complexité temporelle (pire cas) : O(n * m) - même que bottom-up car dans le pire cas
#   tous les sous-problèmes sont calculés
# - Complexité spatiale : O(n²) pour le dictionnaire + O(n) pour la pile d'appels
#   Total : O(n²) (le dictionnaire domine)

def rec_bellman_ford2(G,S):
    L = {}
    n = len(G)
    distances = {}

    def f_rec(i,v):
        # Utilise la mémoisation
        if (i,v) in L:
            return L[(i,v)]

        # Cas de base
        if i == 0:
            if v == S:
                L[(i,v)] = 0
            else:
                L[(i,v)] = np.inf
            return L[(i,v)]

        # Cas n°1 : hériter de la valeur précédente
        val_opt = f_rec(i-1,v)

        # Cas n°2 : Tester tous les prédécesseurs
        for (u,poids) in obtenir_aretes_entrantes(G,v):
            candidat = f_rec(i-1,u) + poids
            val_opt = min(val_opt,candidat)

        # Mémoisation
        L[(i,v)] = val_opt

        return L[(i,v)]

    # Calcul des distances vers tous les sommets
    for sommet in G:
        distances[sommet] = f_rec(n-1,sommet)

    # Appel pour détecter un cycle négatif
    cycle_negatif = False
    for sommet in G:
        test = f_rec(n,sommet)
        if test < distances[sommet]:
            cycle_negatif = True

    return distances, cycle_negatif, L


distance, L = rec_bellman_ford1(graphe,source,'T')
print(distance)
AfficheTable(L,graphe)

distances, cycle_negatif, L = rec_bellman_ford2(graphe,source)
print(distances,cycle_negatif)


#################
# Reconstruction
#################

def determiner_choix(G, L, i, v):
    n = len(G)

    # Test du cas n°1
    if L[(i,v)] == L[(i-1,v)]:
        return ("HERITER",i-1,v)

    # Recherche le prédécesseur du cas n°2
    for sommet,poids in obtenir_aretes_entrantes(G,v):
        if L[(i,v)] == L[(i-1,sommet)] + poids:
            return ("ARETE " + sommet + "->" + v,i-1,sommet)


def reconstruire_chemin(G,L,source, destination):
    i = len(G)
    chemin = [destination]

    while destination != source:
        action, i, destination = determiner_choix(G,L,i,destination)
        if action != "HERITER":
            chemin.append(destination)
    return [chemin[i] for i in range(len(chemin)-1,-1,-1)]


distances, cycle_negatif, L = rec_bellman_ford2(graphe,source)
chemin = reconstruire_chemin(graphe,L,source,'T')
print(chemin)


























